import jieba
import torch
import pickle
from Seq2SeqModel import *
from beam import *

sos_token =0
eos_token =1
MAX_LEN = 7

def evaluatoin_beamsearch_heapq(encoder_outputs, encoder_hidden, DecoderModel, output_lang,
                                beam_width=3, max_len=None):
    """Decode one reply with beam search and return it as a string.

    Hypotheses are stored in ``Beam`` containers (imported via ``from beam import *``)
    as 5-tuples ``(probability, complete, seq, decoder_input, decoder_hidden)``.
    Every step expands each unfinished hypothesis with its ``beam_width``
    highest-scoring next tokens. Hypotheses that already emitted ``eos_token`` are
    carried over unchanged so they keep competing with the longer ones.

    Args:
        encoder_outputs: encoder outputs, passed to ``DecoderModel`` at every step.
        encoder_hidden: encoder hidden state, used as the initial decoder hidden state.
        DecoderModel: callable ``(input, hidden, encoder_outputs)`` returning three
            values ``(output, hidden, _)``; the third value is ignored.
        output_lang: vocabulary object handed to ``decode`` (needs ``index2word``).
        beam_width: number of candidate tokens kept per hypothesis and step.
        max_len: maximum hypothesis length, counting the leading SOS entry;
            defaults to the module-level ``MAX_LEN``.

    Returns:
        The decoded reply, with the leading SOS removed. The trailing EOS is
        also removed when the best hypothesis finished with one.
    """
    if max_len is None:
        max_len = MAX_LEN

    with torch.no_grad():  # inference only: don't build an autograd graph across steps
        prev_beam = Beam()
        # The sequence starts with the SOS id; it is stripped again before decoding.
        prev_beam.add(1, False, [sos_token], torch.LongTensor([sos_token]), encoder_hidden)

        while True:
            cur_beam = Beam()
            for prob, complete, seq, step_input, step_hidden in prev_beam:
                if complete:
                    # Finished hypothesis: keep it as-is; a longer hypothesis may
                    # still end up with a higher score.
                    cur_beam.add(prob, complete, seq, step_input, step_hidden)
                    continue

                output_t, hidden_t, _ = DecoderModel(step_input, step_hidden, encoder_outputs)
                # NOTE(review): assumes output_t is (1, 1, vocab), so row 0 after
                # squeeze(0) holds this step's scores -- confirm against the decoder.
                scores, token_ids = torch.topk(output_t.squeeze(0), beam_width)
                for score, token in zip(scores[0], token_ids[0]):
                    token_id = token.item()
                    # NOTE(review): scores are multiplied, which is only right if the
                    # decoder returns probabilities; log-probabilities must be added.
                    cur_beam.add(prob * score, token_id == eos_token, seq + [token_id],
                                 torch.LongTensor([[token_id]]), hidden_t)

            # Rank by probability only; comparing whole tuples would fall through
            # to (possibly multi-element) tensor comparisons on ties.
            _, best_complete, best_seq, _, _ = max(cur_beam, key=lambda hyp: hyp[0])
            if best_complete or len(best_seq) >= max_len:
                tokens = best_seq[1:]  # drop the leading SOS
                if best_complete:
                    tokens = tokens[:-1]  # drop the trailing EOS (only present if complete)
                return decode(output_lang, tokens)
            prev_beam = cur_beam

def words_tensor(words, lang, max_len=None, pad_index=2, unk_index=3):
    """Convert a list of tokens into a fixed-length 1-D LongTensor of word ids.

    Args:
        words: tokens to encode (e.g. the output of ``jieba.lcut``).
        lang: vocabulary object exposing a ``word2index`` dict.
        max_len: length of the result; defaults to the module-level ``MAX_LEN``.
        pad_index: id appended when ``words`` is shorter than ``max_len``.
        unk_index: id used for tokens missing from ``lang.word2index``.

    Returns:
        ``torch.LongTensor`` of exactly ``max_len`` ids. Longer inputs are
        truncated; shorter ones are right-padded with ``pad_index``.
    """
    if max_len is None:
        max_len = MAX_LEN
    ids = [lang.word2index.get(word, unk_index) for word in words][:max_len]
    ids += [pad_index] * (max_len - len(ids))
    return torch.LongTensor(ids)

# Vocabularies saved at training time. NOTE: pickle can run arbitrary code while
# loading, so only load dictionary files you produced yourself.
with open("dict/input_lang.pkl", "rb") as f:
    input_lang = pickle.load(f)
with open("dict/out_lang.pkl", "rb") as f:
    out_lang = pickle.load(f)

# load_state_dict copies the checkpoint tensors into the modules' existing
# parameters, so mapping the checkpoint to the CPU first works whatever device
# the modules live on. It also lets GPU-saved weights load on machines without CUDA.
# NOTE(review): hidden_size=100 must match the value used at training time.
EncoderModel = Encoder(input_lang.n_words, hidden_size=100)
EncoderModel.load_state_dict(torch.load("savemode/EncoderModel.pkl", map_location="cpu"))
EncoderModel.eval()

DecoderModel = AttentionDencoder(output_size=out_lang.n_words, hidden_size=100)
DecoderModel.load_state_dict(torch.load("savemode/DecoderModel.pkl", map_location="cpu"))
DecoderModel.eval()

def decode(lang, id_len):
    """Map a sequence of word ids back to text.

    Args:
        lang: vocabulary object exposing an ``index2word`` mapping.
        id_len: iterable of word ids.

    Returns:
        The corresponding words concatenated without any separator. An id
        missing from ``index2word`` raises whatever that lookup raises
        (``KeyError`` for a dict).
    """
    return "".join(lang.index2word[i] for i in id_len)


def predict(sentence, input_lang, output_lang, EncoderModel, DecoderModel):
    """Answer ``sentence`` with the seq2seq model, print the reply and return it.

    The sentence is segmented with ``jieba``, converted to a padded id tensor by
    ``words_tensor``, encoded, and decoded with beam search.

    Args:
        sentence: raw user question.
        input_lang: vocabulary used to encode the question.
        output_lang: vocabulary used to decode the answer.
        EncoderModel: encoder called as ``EncoderModel(ids, None)``.
        DecoderModel: decoder handed to the beam search.

    Returns:
        The decoded reply string (it is also printed with an ``"A:"`` prefix).
    """
    input_ids = words_tensor(jieba.lcut(sentence), input_lang)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        # unsqueeze(0) turns the 1-D id tensor into shape (1, len); presumably a
        # batch of one for the encoder -- confirm against Encoder.
        encoder_output, hidden = EncoderModel(input_ids.unsqueeze(0), None)
        answer = evaluatoin_beamsearch_heapq(encoder_output, hidden, DecoderModel, output_lang)
    print("A:", answer)
    return answer

if __name__ == "__main__":
    # Interactive loop: read a question, print the model's answer.
    # Ctrl-D (EOF) or Ctrl-C ends the session without a traceback.
    while True:
        try:
            sentence = input("Q：")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        predict(sentence, input_lang, out_lang, EncoderModel, DecoderModel)




# decoder中的新方法











